/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.paa.apps;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;

import mosdi.fa.CDFA;
import mosdi.fa.FiniteMemoryTextModel;
import mosdi.paa.DeterministicEmitter;
import mosdi.paa.PAA;
import mosdi.paa.ProductMarkovChain;
import mosdi.util.Alphabet;
import mosdi.util.ArrayUtils;
import mosdi.util.HashableIntArray;
import mosdi.util.Iupac;
import mosdi.util.iterators.LexicographicalIterator;

import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;

/** Class that calculates the distribution of the clump size with respect to a 
 *  finite memory text model for a pattern given as a CDFA. */
/** Class that calculates the distribution of the clump size with respect to a 
 *  finite memory text model for a pattern given as a CDFA. A clump is a maximal
 *  run of (possibly overlapping) pattern occurrences; the clump size is the
 *  number of occurrences it contains. */
public class ClumpSizeCalculator {
	private CDFA cdfa;
	// Product of the CDFA and the text model; its states combine an automaton
	// state with a text model state (see markovChain.getAutomatonState()).
	private ProductMarkovChain markovChain;
	// Largest output value emitted by any CDFA state (= max. matches reported per character).
	private int maxEmission;
	// Common length of all strings accepted by the CDFA (see constructor contract).
	private int patternLength;
	// number of steps the last execution of computeClumpSizeDistribution() took until convergence;
	// -1 until that method has been run at least once
	private int stepsToConvergence;

	/** Constructor. It is assumed that all strings accepted by the CDFA
	 *  have the given patternLength. If not, computations yield wrong results. */
	public ClumpSizeCalculator(FiniteMemoryTextModel textModel, CDFA cdfa, int patternLength) {
		this.stepsToConvergence = -1;
		this.cdfa = cdfa;
		this.markovChain = new ProductMarkovChain(cdfa, textModel);
		this.patternLength = patternLength;
		maxEmission = cdfa.getMaxOutput();
	}

	/** Base class for a PAA based on the ProductMarkovChain and emitting the number
	 *  of matches. */
	private abstract class InnerPAA extends PAA implements DeterministicEmitter {
		@Override
		public int getStateCount() { return markovChain.getStateCount(); }
		@Override
		public double transitionProbability(int state, int targetState) {
			return markovChain.getTransitionProbability(state, targetState);
		}
		@Override
		public double[][] stateValueStartDistribution() {
			// All probability mass starts in the subclass' start value, spread over
			// states according to the Markov chain's initial distribution.
			double[][] result = new double[getStateCount()][getValueCount()];
			int startValue = getStartValue();
			for (int state=0; state<getStateCount(); ++state) {
				result[state][startValue] = markovChain.getInitialProbability(state);
			}
			return result;
		}
		@Override
		public double emissionProbability(int state, int emission) {
			// Emissions are deterministic: probability 1 for this state's output, 0 otherwise.
			return getEmission(state)==emission?1.0:0.0;
		}
		@Override
		public int getStartState() { throw new UnsupportedOperationException();	}
		@Override
		public int getEmission(int state) {
			// Emission = output of the underlying CDFA state, i.e. the number of matches
			// reported when entering this state.
			return cdfa.getStateOutput(markovChain.getAutomatonState(state));
		}
		@Override
		public int getEmissionCount() {
			return maxEmission+1;
		}
	}

	/** PAA to compute the clump start distribution. Here, the current value is
	 *  the number of steps since the last match. The value patternLength+1 has
	 *  a special meaning: it is used when a clump start has been observed. 
	 */
	private class ClumpStartPAA extends InnerPAA {
		@Override
		public int getStartValue() {
			// Start as if the last match were >= patternLength steps in the past,
			// so the first match ever observed counts as a clump start.
			return patternLength;
		}
		@Override
		public int getValueCount() { return patternLength+2; }
		@Override
		public int performOperation(int state, int value, int emission) {
			// The previous step flagged a clump start; that match happened 0 steps ago.
			if (value==patternLength+1) value = 0;
			// One more character read; the counter saturates at patternLength since
			// larger distances are equivalent (no overlap possible anymore).
			value = Math.min(patternLength, value+1);
			if (emission>0){
				// Match observed: it starts a new clump iff the previous match lies at
				// least patternLength characters back (the occurrences cannot overlap).
				return (value==patternLength)?patternLength+1:0;
			}
			return value;
		}
	}

	/** Computes state distribution at clump start (states w.r.t. markovChain). */
	private double[] computeMCClumpStartDistribution() {
		ClumpStartPAA paa = new ClumpStartPAA();
		double[][] eq = paa.convergeToStateValueEquilibrium(1e-13, 10000000);
		double[] result = new double[paa.getStateCount()];
		double pTotal = 0.0;
		for (int state=0; state<paa.getStateCount(); ++state) {
			// Only match states can carry the clump-start flag patternLength+1.
			if (paa.getEmission(state)==0) continue;
			result[state]+=eq[state][patternLength+1];
			pTotal+=eq[state][patternLength+1];
		}
		if (pTotal<=0.0) {
			throw new IllegalStateException("Couldn't compute state distribution at clump starts, probably the expected clump size is infinite.");
		}
		// Normalize: condition on actually being at a clump start.
		for (int state=0; state<paa.getStateCount(); ++state) {
			result[state]/=pTotal;
		}
		return result;
	}

	/** Computes state distribution at clump start (states w.r.t. CDFA). */
	public double[] computeClumpStartDistribution() {
		double[] startDistribution = computeMCClumpStartDistribution();
		double[] result = new double[cdfa.getStateCount()];
		// Project product-chain states down to their CDFA component.
		for (int i=0; i<startDistribution.length; ++i) {
			result[markovChain.getAutomatonState(i)]+=startDistribution[i];
		}
		return result;
	}

	/** Specialized implementation of recurrences. Generic PAA implementation could be used, but that would be slower!
	 *  @param maxClumpSize distribution is truncated at this clump size
	 *  @param accuracy iteration stops once the probability mass still inside the table is below this value
	 *  @param clumpStartDistribution state distribution (w.r.t. markovChain) at clump starts
	 *  @return array r where r[n] approximates the probability of clump size n
	 */
	private double[] computeClumpSizeDistribution(int maxClumpSize, double accuracy, double[] clumpStartDistribution) {
		double[] result = new double[maxClumpSize+1];
		int stateCount = markovChain.getStateCount();
		// cache all emissions (number of matches reported by each product state)
		int[] emissions = new int[stateCount];
		for (int state=0; state<stateCount; ++state) {
			emissions[state] = cdfa.getStateOutput(markovChain.getAutomatonState(state));
		}
		// dist[state][l][n] is the probability of being in state 'state', having seen n 
		// matches and l non-match characters since the last match.
		// Match states only need layer l=0 (a match has just been seen); their other
		// layers stay null, which the update loops below must respect.
		// NOTE(review): patternLength>=2 is assumed here (arrays of size patternLength-1) — confirm.
		double[][][] dist1 = new double[stateCount][][];
		double[][][] dist2 = new double[stateCount][][];
		for (int state=0; state<stateCount; ++state) {
			dist1[state] = new double[patternLength-1][];
			dist2[state] = new double[patternLength-1][];
			if (emissions[state]>0) {
				dist1[state][0] = new double[maxClumpSize+1];
				dist2[state][0] = new double[maxClumpSize+1];
			} else {
				for (int l=0; l<patternLength-1; ++l) {
					dist1[state][l] = new double[maxClumpSize+1];
					dist2[state][l] = new double[maxClumpSize+1];
				}
			}
			// Initialize with the clump start distribution: the clump-opening match
			// contributes emissions[state] occurrences right away.
			if (emissions[state]<=maxClumpSize) {
				dist1[state][0][emissions[state]] = clumpStartDistribution[state];
			}
		}
		// main loop: advance the distribution one text character per iteration
		stepsToConvergence = 0;
		while (true) {
			// total probability mass in the table
			double pTotal = 0.0;
			// iterate over all states
			for (int state=0; state<stateCount; ++state) {
				int[] preimage = markovChain.getPreimage(state);
				double[] preimageProb = markovChain.getPreimageProbabilities(state);
				if (emissions[state]>0) {
					// Match state: the clump continues; the match count grows by emissions[state].
					for (int targetClumpSize=emissions[state]; targetClumpSize<=maxClumpSize; ++targetClumpSize) {
						int nSource = targetClumpSize-emissions[state];
						double p = 0.0;
						for (int k=0; k<preimage.length; ++k) {
							int sourceState = preimage[k];
							// Sum over all layers of a non-match source; a match source
							// only has layer 0 (break before hitting its null layers).
							for (double[] d : dist1[sourceState]) {
								p+=d[nSource]*preimageProb[k];
								if (emissions[sourceState]>0) break;
							}
						}
						dist2[state][0][targetClumpSize]=p;
						pTotal+=p;
					}
				} else {
					for (int clumpSize=0; clumpSize<=maxClumpSize; ++clumpSize) {
						for (int l=1; l<patternLength-1; ++l) {
							double p = 0.0;
							for (int k=0; k<preimage.length; ++k) {
								int sourceState = preimage[k];
								// Match sources only populate layer 0, so they can only
								// feed layer l=1; non-match sources shift l-1 -> l.
								if ((emissions[sourceState]==0)||(l==1)) {
									p+=dist1[sourceState][l-1][clumpSize]*preimageProb[k];
								}
							}
							dist2[state][l][clumpSize]=p;
							pTotal+=p;
						}
						// allow probability mass to leave the table and gather in 
						// clump size distribution: patternLength-1 consecutive non-match
						// characters end the clump
						for (int k=0; k<preimage.length; ++k) {
							int sourceState = preimage[k];
							// null for match states (unless patternLength==2): they hold layer 0 only
							if (dist1[sourceState][patternLength-2]!=null) {
								result[clumpSize]+=dist1[sourceState][patternLength-2][clumpSize]*preimageProb[k];
							}
						}
					}
				}
			}
			stepsToConvergence += 1;
			// swap distributions; dist2's contents are fully overwritten next round
			double[][][] h = dist1;
			dist1=dist2;
			dist2=h;
			// check whether the probability mass left in the table is small enough
			if (pTotal<accuracy) break;
		}
		return result;
	}

	/** Returns the number of steps the last execution of computeClumpSizeDistribution() took until convergence.
	 *  Returns -1 if that computation has not been run yet. */
	public int stepsToConvergence() {
		return this.stepsToConvergence;
	}

	/** Returns the number of states of the underlying product Markov chain. */
	public int getProductStateCount() {
		return markovChain.getStateCount();
	}

	/** Computes the (truncated) clump size distribution up to maxClumpSize.
	 *  @param accuracy Entries in returned array sum to >=1.0-accuracy.
	 */
	public double[] clumpSizeDistribution(int maxClumpSize, double accuracy) {
		double[] clumpStartDistribution = computeMCClumpStartDistribution();
		return computeClumpSizeDistribution(maxClumpSize,accuracy,clumpStartDistribution);
	}


	/** Returns true if p1 can be right-overlapped by p2 with the given shift,
	 *  i.e. p2 may start at position 'shift' of p1 without character conflicts
	 *  in the overlapping region. */
	private static boolean overlapPossible(int[] p1, int[] p2, int shift) {
		for (int i=0; i<p2.length; ++i) {
			if (i+shift>=p1.length) break;
			if (p1[i+shift]!=p2[i]) return false;
		}
		return true;
	}

	/** If set C is given, $C \setminus C\Sigma^+$ is returned, that is, all
	 *  elements of C that do not have a proper prefix in C. */
	private static Set<HashableIntArray> makePrefixFree(Set<HashableIntArray> set) {
		Set<HashableIntArray> result = new HashSet<HashableIntArray>();
		// Process strings by increasing length; n counts processed elements.
		for (int length=0, n=0; n!=set.size(); length+=1) {
			for (HashableIntArray s : set) {
				if (s.array().length!=length) continue;
				n+=1;
				// Keep s only if none of its prefixes has been kept already.
				boolean found = false;
				for (int i=0; i<length; ++i) {
					HashableIntArray t = new HashableIntArray(Arrays.copyOfRange(s.array(),0,i+1));
					if (result.contains(t)) {
						found = true;
						break;
					}
				}
				if (!found) result.add(s);
			}
		}
		return result;
	}

	/** Prints the overlap matrix (see overlapMatrix()) in a human-readable form,
	 *  one row per (pattern, text model state) pair. */
	public static void printOverlapMatrix(PrintStream ps, Alphabet alphabet, FiniteMemoryTextModel textModel, List<int[]> patterns) {
		if (alphabet.size()!=textModel.getAlphabetSize()) throw new IllegalArgumentException("Alphabet size mismatch");
		int length = patterns.get(0).length;
		for (int[] pattern : patterns) {
			if (pattern.length!=length) throw new IllegalArgumentException("All patterns must have same length.");
		}
		Matrix matrix = overlapMatrix(textModel, patterns, length);
		int n = patterns.size();
		for (int source=0; source<n; ++source) {
			for (int sourceState=0; sourceState<textModel.getStateCount(); ++sourceState) {
				ps.print(String.format("(%s,%d) |", alphabet.buildString(patterns.get(source)),sourceState));
				for (int target=0; target<n; ++target) {
					for (int targetState=0; targetState<textModel.getStateCount(); ++targetState) {
						// matrix index layout matches overlapMatrix(): state*n+pattern
						int i = sourceState*n+source;
						int j = targetState*n+target;
						// four decimal places (was "%04f", which zero-pads the width instead)
						ps.print(String.format(" %.4f",matrix.getAsDouble(i,j)));
					}
				}
				ps.println();
			}
		}
	}

	/** Computes the overlap matrix. Entry (sourceState*n+source, targetState*n+target)
	 *  is the probability that an occurrence of pattern 'source' whose suffix starting
	 *  at the overlap position is produced from text model state sourceState is directly
	 *  right-overlapped by an occurrence of pattern 'target' starting in state targetState.
	 *  NOTE(review): exact conditioning depends on FiniteMemoryTextModel semantics — verify. */
	private static Matrix overlapMatrix(FiniteMemoryTextModel textModel, List<int[]> patterns, int length) {
		int n = patterns.size();
		int m = textModel.getStateCount() * patterns.size();
		Matrix result = MatrixFactory.zeros(m,m);
		for (int source=0; source<patterns.size(); ++source) {
			int[] sourcePattern = patterns.get(source);
			// Maps possible extensions onto string indices. An extension is the part of
			// an overlapping target occurrence that sticks out beyond the source pattern;
			// it determines the target pattern uniquely (equal suffix + equal overlap region).
			Map<HashableIntArray,Integer> extensions = new HashMap<HashableIntArray,Integer>();
			for (int target=0; target<patterns.size(); ++target) {
				int[] targetPattern = patterns.get(target);
				for (int shift=1; shift<length; ++shift) {
					if (overlapPossible(sourcePattern, targetPattern, shift)) {
						extensions.put(new HashableIntArray(Arrays.copyOfRange(targetPattern, targetPattern.length-shift,length)), target);
					}
				}
			}
			// Reduce to a prefix-free set so each continuation is counted only once
			// (only the nearest overlapping occurrence matters).
			Set<HashableIntArray> p = makePrefixFree(extensions.keySet());
			for (HashableIntArray u : p) {
				int[] extension = u.array();
				int target = extensions.get(u);
				for (int sourceState=0; sourceState<textModel.getStateCount(); ++sourceState) {
					// Split the source pattern at the overlap position: the target
					// occurrence starts right after 'prefix'.
					int[] prefix = Arrays.copyOf(sourcePattern, extension.length);
					int[] infix  = Arrays.copyOfRange(sourcePattern, extension.length, sourcePattern.length);
					// Distribution over text model states after producing 'prefix',
					// conditioned on producing it (hence the normalization).
					double[] d = textModel.statewiseProductionProbability(sourceState, prefix);
					ArrayUtils.normalize(d);
					for (int targetState=0; targetState<textModel.getStateCount(); ++targetState) {
						// Continue from targetState through the rest of the source pattern,
						// weighted by the probability of reaching targetState at all, ...
						double[] d1 = textModel.statewiseProductionProbability(targetState, infix);
						ArrayUtils.normalize(d1, d[targetState]);
						// ... then produce the extension, completing the target occurrence.
						double[] d2 = textModel.statewiseProductionProbability(d1, extension);
						int i = sourceState*n+source;
						int j = targetState*n+target;
						result.setAsDouble(result.getAsDouble(i,j)+ArrayUtils.sum(d2), i, j);
					}
				}
			}
		}
		return result;
	}

	/** Directly computes the expected clump size (without computing the whole distribution).  
	 *  @param patterns All patterns must have the same length and be given over the same alphabet
	 *                  as the textModel.
	 *  @param weights  Array containing a weight for each pattern; may be null, in which case
	 *                  all patterns get weight 1.
	 */
	public static double expectedClumpSize(FiniteMemoryTextModel textModel, List<int[]> patterns, double[] weights) {
		// 0) Check input
		if (patterns.isEmpty()) throw new IllegalArgumentException("At least one pattern must be given.");
		int length = patterns.get(0).length;
		for (int[] pattern : patterns) {
			if (pattern.length!=length) throw new IllegalArgumentException("All patterns must have same length.");
		}
		int n = patterns.size();
		int m = textModel.getStateCount() * patterns.size();
		// 1) Compute equilibrium (row)vector: probability that an occurrence of a given
		//    pattern starts while the text model is in a given state, normalized over
		//    all (state, pattern) pairs.
		Matrix p = MatrixFactory.zeros(1,m);
		double sum = 0.0;
		double[] eq = textModel.getEquilibriumDistribution();
		for (int pattern=0; pattern<n; ++pattern) {
			for (int state=0; state<textModel.getStateCount(); ++state) {
				double prob = eq[state] * textModel.productionProbability(state, patterns.get(pattern));
				p.setAsDouble(prob, 0, state*n+pattern);
				sum += prob;
			}
		}
		// normalize to 1.0
		for (int i=0; i<m; ++i)	p.setAsDouble(p.getAsDouble(0,i)/sum, 0, i);
		Matrix overlapMatrix = overlapMatrix(textModel, patterns, length);
		Matrix ones = MatrixFactory.ones(m,1);
		// q = p * overlapMatrix * ones is the probability that an occurrence is directly
		// overlapped by another one; the expected clump size is 1/(1-q) (geometric law).
		if (weights==null) {
			return 1.0/(1.0 - p.mtimes(overlapMatrix).mtimes(ones).doubleValue());
		} else {
			// Weighted variant: expected weight of an occurrence divided by (1-q).
			Matrix w = MatrixFactory.zeros(m,1);
			for (int state=0; state<textModel.getStateCount(); ++state) {
				for (int pattern=0; pattern<n; ++pattern) {
					w.setAsDouble(weights[pattern], state*n+pattern, 0);
				}
			}
			return p.mtimes(w).doubleValue()/(1.0 - p.mtimes(overlapMatrix).mtimes(ones).doubleValue());
		}
	}

	/** Directly computes the expected clump size (without computing the whole distribution).  
	 *  @param patterns All patterns must have the same length and be given over the same alphabet
	 *                  as the textModel.
	 */
	public static double expectedClumpSize(FiniteMemoryTextModel textModel, List<int[]> patterns) {
		return expectedClumpSize(textModel, patterns, null);
	}

	/** Directly computes the expected clump size (without computing the whole distribution).  
	 *  @param textModel A text model over the DNA alphabet.
	 *  @param iupacPattern Pattern given over the IUPAC alphabet, see Alphabet.getIupacAlphabet().
	 *  @param considerReverse If true, occurrences of the reverse complement are counted as well;
	 *                         instances matched on both strands get weight 2.
	 */
	public static double expectedClumpSize(FiniteMemoryTextModel textModel, int[] iupacPattern, boolean considerReverse) {
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		// maps strings to their multiplicity (1, or 2 if matched by both strands)
		Map<String,Integer> patternMap = new HashMap<String,Integer>();
		LexicographicalIterator iterator = Iupac.patternInstanceIterator(iupacPattern);
		while (iterator.hasNext()) patternMap.put(dnaAlphabet.buildString(iterator.next()),1);
		if (considerReverse) {
			iterator = Iupac.patternInstanceIterator(Iupac.reverseComplementary(iupacPattern));
			while (iterator.hasNext()) {
				String s = dnaAlphabet.buildString(iterator.next());
				// Instances produced by both the pattern and its reverse complement count twice.
				if (patternMap.containsKey(s)) {
					patternMap.put(s, 2);
				}
				else patternMap.put(s, 1);
			}
		}
		List<int[]> patterns = new ArrayList<int[]>();
		double[] weights = new double[patternMap.size()];
		int n = 0;
		for (Entry<String,Integer> e : patternMap.entrySet()) {
			patterns.add(dnaAlphabet.buildIndexArray(e.getKey()));
			weights[n++] = e.getValue();
		}
		// Without reverse complements all weights are 1, so they can be omitted.
		return expectedClumpSize(textModel, patterns, considerReverse?weights:null);
	}

}
